{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 第 2 章：PyTorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'1.0.1'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 导入PyTorch库\n",
    "import torch\n",
    "import torchvision\n",
    "\n",
    "# 查看安装的PyTorch版本\n",
    "torch.__version__"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 张量 Tensor"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 创建 Tensor"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "创建一个随机初始化的 Tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-0.3249,  0.4365],\n",
      "        [ 0.3976, -0.4804]])\n"
     ]
    }
   ],
   "source": [
    "x = torch.randn(2,2)\n",
    "print(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "直接把 Python 列表构建成 Tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1, 2],\n",
      "        [3, 4]])\n"
     ]
    }
   ],
   "source": [
    "x = torch.tensor([[1, 2], [3, 4]])\n",
    "print(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "创建一个全零 Tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0., 0.],\n",
      "        [0., 0.]])\n"
     ]
    }
   ],
   "source": [
    "x = torch.zeros(2,2)\n",
    "print(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "基于现有的 Tensor 创建新的 Tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0., 0.],\n",
      "        [0., 0.]])\n",
      "tensor([[1., 1.],\n",
      "        [1., 1.]])\n"
     ]
    }
   ],
   "source": [
    "x = torch.zeros(2,2)\n",
    "y = torch.ones_like(x)\n",
    "print(x)\n",
    "print(y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "指定 Tensor 数据类型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1, 1],\n",
      "        [1, 1]])\n"
     ]
    }
   ],
   "source": [
    "x = torch.ones(2, 2, dtype=torch.long)\n",
    "print(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Tensor 的数学运算"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "两个 Tensor 相加"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[2., 2.],\n",
      "        [2., 2.]])\n"
     ]
    }
   ],
   "source": [
    "x = torch.ones(2,2)\n",
    "y = torch.ones(2,2)\n",
    "z = x + y\n",
    "print(z)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "也可以使用torch.add()实现Tensor相加："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[2., 2.],\n",
      "        [2., 2.]])\n"
     ]
    }
   ],
   "source": [
    "x = torch.ones(2,2)\n",
    "y = torch.ones(2,2)\n",
    "z = torch.add(x, y)\n",
    "print(z)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "还可以使用._add()实现替换："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[2., 2.],\n",
      "        [2., 2.]])\n"
     ]
    }
   ],
   "source": [
    "x = torch.ones(2,2)\n",
    "y = torch.ones(2,2)\n",
    "y.add_(x)\n",
    "print(y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Tenosr乘法有两种形式，第一种是对应元素相乘："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 1,  4],\n",
       "        [ 9, 16]])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.tensor([[1, 2], [3, 4]])\n",
    "y = torch.tensor([[1, 2], [3, 4]])\n",
    "x.mul(y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "第二种更常用的是矩阵相乘："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 7, 10],\n",
       "        [15, 22]])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.tensor([[1, 2], [3, 4]])\n",
    "y = torch.tensor([[1, 2], [3, 4]])\n",
    "x.mm(y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Tensor 与 NumPy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 导入numpy\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Tensor to NumPy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'torch.Tensor'>\n",
      "<class 'numpy.ndarray'>\n"
     ]
    }
   ],
   "source": [
    "a = torch.ones(2,2)\n",
    "b = a.numpy()\n",
    "print(type(a))\n",
    "print(type(b))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "此时，如果Tensor发生改变，对应的NumPy数组也有相同的变化。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[2., 2.],\n",
      "        [2., 2.]])\n",
      "[[2. 2.]\n",
      " [2. 2.]]\n"
     ]
    }
   ],
   "source": [
    "a.add_(1)\n",
    "print(a)\n",
    "print(b)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### NumPy to Tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'numpy.ndarray'>\n",
      "<class 'torch.Tensor'>\n"
     ]
    }
   ],
   "source": [
    "a = np.array([[1, 1], [1, 1]])\n",
    "b = torch.from_numpy(a)\n",
    "print(type(a))\n",
    "print(type(b))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "如果NumPy数组发生改变，对应的Tensor也有相同的变化。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[2 2]\n",
      " [2 2]]\n",
      "tensor([[2, 2],\n",
      "        [2, 2]], dtype=torch.int32)\n"
     ]
    }
   ],
   "source": [
    "np.add(a, 1, out=a)\n",
    "print(a)\n",
    "print(b)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CUDA Tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = torch.ones(2,2)\n",
    "# 检查是否可以使用GPU\n",
    "if torch.cuda.is_available():\n",
    "    a_cuda = a.cuda()\n",
    "    print(a_cuda)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "因为我们安装的是CPU版本的PyTorch，所以这里不会执行if语句。如果安装了GPU，a_cuda的打印结果如下：\n",
    "\n",
    "```\n",
    "tensor([[1., 1.],\n",
    "        [1., 1.]], device='cuda:0')\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 自动求导 autograd"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "定义 Tensor x，设置参数`tensor.requries_grad=True`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n",
      "True\n"
     ]
    }
   ],
   "source": [
    "x = torch.ones(2, 2, requires_grad=True)\n",
    "y = torch.ones(2, 2, requires_grad=True)\n",
    "print(x.requires_grad)\n",
    "print(y.requires_grad)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 当输出是标量时"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "定义输出$z=\\frac14\\sum_ix_i+y_i$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(2., grad_fn=<MeanBackward1>)\n"
     ]
    }
   ],
   "source": [
    "z = x + y\n",
    "z = z.mean()\n",
    "print(z)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "反向传播"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "z.backward()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "计算$\\frac{\\partial z}{\\partial x}$和$\\frac{\\partial z}{\\partial y}$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0.2500, 0.2500],\n",
      "        [0.2500, 0.2500]])\n"
     ]
    }
   ],
   "source": [
    "print(x.grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0.2500, 0.2500],\n",
      "        [0.2500, 0.2500]])\n"
     ]
    }
   ],
   "source": [
    "print(y.grad)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 当输出是多维张量时"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "定义输出 $z=2x+3y$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[5., 5.],\n",
      "        [5., 5.]], grad_fn=<AddBackward0>)\n"
     ]
    }
   ],
   "source": [
    "x = torch.ones(2, 2, requires_grad=True)\n",
    "y = torch.ones(2, 2, requires_grad=True)\n",
    "z = 2 * x + 3 * y\n",
    "print(z)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "反向传播"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "z.backward(torch.ones_like(z))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "计算$\\frac{\\partial z}{\\partial x}$和$\\frac{\\partial z}{\\partial y}$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[2., 2.],\n",
      "        [2., 2.]])\n"
     ]
    }
   ],
   "source": [
    "print(x.grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[3., 3.],\n",
      "        [3., 3.]])\n"
     ]
    }
   ],
   "source": [
    "print(y.grad)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 禁止自动求导"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n",
      "True\n",
      "False\n"
     ]
    }
   ],
   "source": [
    "print(x.requires_grad)\n",
    "print((2 * x).requires_grad)\n",
    "\n",
    "with torch.no_grad():\n",
    "    print((2 * x).requires_grad)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 神经网络包 nn 和优化器 optim"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### torch.nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "class net_name(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(net_name, self).__init__()\n",
    "        self.fc = nn.Linear(1, 1)\n",
    "        # 其它层\n",
    "    \n",
    "    def forward(self, x):\n",
    "        out = self.fc(x)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "新建一个该模型的对象"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = net_name()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### torch.optim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.optim as optim"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "计算预测值与真实值的均方误差"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```\n",
    "criterion = nn.MSELoss()\n",
    "loss = criterion(output, target)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "使用随机梯度下降（SGD)优化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```\n",
    "optimizer = optim.SGD(net.parameters(), lr=0.01)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "单次迭代对应的代码为"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```\n",
    "optimizer.zero_grad()    # 梯度清零\n",
    "output = net(input)\n",
    "loss = criterion(output, target)\n",
    "loss.backward()\n",
    "optimizer.step()    # 完成更新\n",
    "\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## PyTorch 线性回归"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 创建数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "# y=3x+10，后面加上torch.randn()函数制造噪音\n",
    "x = torch.unsqueeze(torch.linspace(-1, 1, 50), dim=1)\n",
    "y = 3 * x + 10 + 0.5 * torch.randn(x.size())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "显示数据分布"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Figure size 1000x600 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "plt.rcParams['font.sans-serif']=['SimHei']\n",
    "plt.rcParams['axes.unicode_minus']=False\n",
    "plt.rcParams['figure.figsize'] = (10.0, 6.0) # set default size of plots\n",
    "\n",
    "plt.scatter(x.numpy(), y.numpy())\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 定义模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LinearRegression(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(LinearRegression, self).__init__()\n",
    "        self.fc = nn.Linear(1, 1)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        out = self.fc(x)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = LinearRegression()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 定义 loss 和优化函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义loss和优化函数\n",
    "criterion = nn.MSELoss()\n",
    "optimizer = optim.SGD(model.parameters(), lr=5e-3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch[20/1000], loss: 58.123623\n",
      "Epoch[40/1000], loss: 39.301388\n",
      "Epoch[60/1000], loss: 26.664478\n",
      "Epoch[80/1000], loss: 18.171267\n",
      "Epoch[100/1000], loss: 12.455207\n",
      "Epoch[120/1000], loss: 8.601423\n",
      "Epoch[140/1000], loss: 5.997334\n",
      "Epoch[160/1000], loss: 4.232640\n",
      "Epoch[180/1000], loss: 3.032410\n",
      "Epoch[200/1000], loss: 2.212354\n",
      "Epoch[220/1000], loss: 1.648835\n",
      "Epoch[240/1000], loss: 1.258876\n",
      "Epoch[260/1000], loss: 0.986709\n",
      "Epoch[280/1000], loss: 0.794804\n",
      "Epoch[300/1000], loss: 0.657871\n",
      "Epoch[320/1000], loss: 0.558824\n",
      "Epoch[340/1000], loss: 0.486085\n",
      "Epoch[360/1000], loss: 0.431788\n",
      "Epoch[380/1000], loss: 0.390558\n",
      "Epoch[400/1000], loss: 0.358707\n",
      "Epoch[420/1000], loss: 0.333685\n",
      "Epoch[440/1000], loss: 0.313713\n",
      "Epoch[460/1000], loss: 0.297540\n",
      "Epoch[480/1000], loss: 0.284271\n",
      "Epoch[500/1000], loss: 0.273265\n",
      "Epoch[520/1000], loss: 0.264049\n",
      "Epoch[540/1000], loss: 0.256270\n",
      "Epoch[560/1000], loss: 0.249662\n",
      "Epoch[580/1000], loss: 0.244020\n",
      "Epoch[600/1000], loss: 0.239182\n",
      "Epoch[620/1000], loss: 0.235021\n",
      "Epoch[640/1000], loss: 0.231432\n",
      "Epoch[660/1000], loss: 0.228331\n",
      "Epoch[680/1000], loss: 0.225646\n",
      "Epoch[700/1000], loss: 0.223320\n",
      "Epoch[720/1000], loss: 0.221302\n",
      "Epoch[740/1000], loss: 0.219550\n",
      "Epoch[760/1000], loss: 0.218029\n",
      "Epoch[780/1000], loss: 0.216707\n",
      "Epoch[800/1000], loss: 0.215558\n",
      "Epoch[820/1000], loss: 0.214558\n",
      "Epoch[840/1000], loss: 0.213690\n",
      "Epoch[860/1000], loss: 0.212934\n",
      "Epoch[880/1000], loss: 0.212276\n",
      "Epoch[900/1000], loss: 0.211704\n",
      "Epoch[920/1000], loss: 0.211207\n",
      "Epoch[940/1000], loss: 0.210774\n",
      "Epoch[960/1000], loss: 0.210397\n",
      "Epoch[980/1000], loss: 0.210070\n",
      "Epoch[1000/1000], loss: 0.209784\n"
     ]
    }
   ],
   "source": [
    "num_epochs = 1000    # 遍历整个训练集的次数\n",
    "for epoch in range(num_epochs):\n",
    "    # forward\n",
    "    out = model(x)    #前向传播\n",
    "    loss = criterion(out, y)    #计算loss\n",
    "    # backward\n",
    "    optimizer.zero_grad()    #梯度归零\n",
    "    loss.backward()          #反向传播\n",
    "    optimizer.step()         #更新参数\n",
    " \n",
    "    if (epoch+1) % 20 == 0:\n",
    "        print('Epoch[{}/{}], loss: {:.6f}'.format(epoch+1, num_epochs, loss.detach().numpy()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 模型测试"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXEAAAD6CAYAAABXh3cLAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3XmczeX7x/HXbYxMQ5ZI1qxJUtREZWmSklLJ2qbSr5RUKlskLYiiRSVSkqxl+dJXKZUGETX2IluI8ZXIWCfGzP374zMzGc6ZOefMWWfez8fDozlnPud8rnOaueY+9+e6r9tYaxERkchUKNQBiIiI75TERUQimJK4iEgEUxIXEYlgSuIiIhFMSVxEJIIpiYuIRDAlcRGRCKYkLiISwQoH+gRlypSxVatWDfRpRETylRUrVuyz1pbN7biAJ/GqVauSmJgY6NOIiOQrxpgdnhyn6RQRkQimJC4iEsGUxEVEIljA58RdSU1NZdeuXfzzzz+hOH1EKVq0KJUqVSI6OjrUoYhIGApJEt+1axfFixenatWqGGNCEUJEsNayf/9+du3aRbVq1UIdjoiEoZBMp/zzzz+ce+65SuC5MMZw7rnn6hOLiLgVkpE4oATuIb1PIoE3e1USw7/eyO7kFCqUjKF3y9q0aVAx1GF5RBc2PXTw4EE82cru+PHjpKenByEiEfGH2auS6DdrHUnJKVggKTmFfrPWMXtVUqhD80iBTeJ//PEHbdu2zbo9atQodu/ezcqVK5k0adIZx995552MGDHijPtTU1P59ttvSUhIICEhgS5dutCnT5+s299++23Wsbt37852zvvuu4/ly5czbdo0P786EfHU8K83kpKalu2+lNQ0hn+9MUQReadAJvETJ05QuHBhYmJiSE1NpWfPnlSqVIlXXnmFAwcOsGXLlmzHDx8+nGbNmrFx40a++uqrbN9LT09n165d7NmzJ+tfuXLlsr7etWtX1rGjRo3i6aefZtu2bQBER0dTqlQptm/fTlJSZPzVF8lvdieneHV/uAnZnHgozZ49m/fee4/NmzczYsQIfvnlF4YPH86wYcNo3bo1hQr9+7ft9ddfZ+vWrYwZM4YTJ05wzz33sHbtWp5++mmio6M566yzeOCBB7jzzjvZt28fv/76K6mpqZx11lkUK1aM2bNnA7BhwwYOHz6cdfzzzz/P5s2b6devHyVLlmTQoEGMGTMmVG+JSIFVoWQMSS4SdoWSMSGIxnuhT+JPPQWrV/v3OevXh7fecvvtjh070rx5c3r16kXfvn1ZsmQJhQoV4ocffuD7778nKiqK33//ne7duxMbG0u9evUYPHgwAHXr1mXr1q3Ur1+fMWPG0LRpUwC2b9/OsGHDsp2nb9++WV9PnjyZpUuXMn36dBISEvjwww85//zzefLJJ2ncuDFr1qzhjz/+oEqVKv59L0QkR71b1qbfrHXZplRioqPo3bJ2CKPyXOiTeIisXbuWhQsX8t5772Xdl56ezrFjxzjrrLOIjY2lf//+XHLJJWzbti1rdJ6enk6tWrU4evQoxYsXz3psWloae/bsyXaOUy9wDh48mLZt2/LNN99Qs2ZNqlatyk8//cRzzz1HhQoVqFy5Mg8++GCAX7WInC6zCiVSq1OMJxUXeREXF2dP72K4YcMG6tSpE9Dz5mTixIn8+OOPHD16lAkTJtC6dWu6d+/O8uXLKVGiBFWqVKFdu3akp6dzySWXUKFChWyPT0lJYcmSJdnuu+iiizj//POz3bdz5062bt0KwMaNG7nlllu49957qVWrFuXKlcNay5dffknDhg1ZtGgRo0ePdhlvqN8vEQk+Y8wKa21cbscVyJF4586dadWqFb169QKcZPv+++8zZcoU4uPjeeihhwAoVKgQ559/frYKE4CrrrrqjOeMjY2lffv22e776KOPsr4uVaoUI0aMoHr16hQvXpxq1aqRlpZG9+7d
+e2335g5c6a/X6aIFAAFMokDWTXfaWlptGzZkr59+9K9e3c6dOjApEmTuPbaa6lduzZr166lRYsW2R574MCBrK9PnjyJMYYiRYpQv379bMcVLVqU9PR00tPTWblyJZMnTyYtLY3SpUszbNgwBgwYQNWqVbnwwgvZvHkzmzdvPuMPgYhITgpkEk9PT6dTp0507NiRqKgoHn74Ya699lq6detG9+7diY+PZ+nSpdSqVYt69eqdMRK/4oorsr6eMmUKH3/8MbGxsbz44ovZjjv77LO54YYb6NixI23btqVOnTpccMEF7Ny5k+bNm/PKK6/QunVrXnrpJa6//nqNxkXEaz7NiRtjygN1geXW2sM5HRuOc+LgjMQzl7Rba9mzZw/ly5cPyfnB6SdTtGhRl8eGw/slIsHl6Zy4R4t9jDHljDGLM76+EPgUaAwsNMYUyVOkIXJqAjXGBDWBn35+wG0CFxHJSa7TKcaYUsAEIDbjrkuBLtbarcaYekA1IDLWp4qI5DOejMTTgE7AIQBr7QxghzHmFqAUsCWHx4qISADlmsSttYestQdPu7sY0BHYAZwxqW6M6WqMSTTGJP7111/+iVRERM7gUwMsa22ytfZ+IBq40sX3x1pr46y1cWXLls1rjCGRlpaGtZbjx49z5MiRrPuttZw8eTKEkYmI/MvrEkNjzGhgqrV2EVASSPZ7VEEybtw4RowYQcWKzvLaUytE0tLSmDlzJlu3buXjjz9m1KhRAMybN4+ZM2cybty4bM81cOBArrvuOr799luKFy9O9+7dad++PV9++SVRUVFZx508eZKOHTsya9Ys7r77bnbv3p31vWLFijF37txAv2wRyUd8qRN/DZhojLHAfGttxF7UjIqKokePHjz66KMANGrUiMWLF2dVjhw6dIjVq1dTq1Yt/vzzT8qVK8eECRN49dVX+fzzz2natCmlSpXiyJEjnHPOOfz444/s3buXPXv2sGPHDmJjY4mKisrqoWKMoXDhwpw4cQJwepEnJCRkxaOFPiLiLY+TuLU2PuO/24AmgQrIlUBvndSrVy+WLVtGcnIy1113HcePH6dDhw7UrFmTkSNHcvHFFzN69GjGjx9PkyZNqFq1Kps2beKxxx5j6tSpHDx4kP379/Puu+9Sv359rrnmGt599122bNlCs2bN2LJlC7Nnz2bnzp2MHTuW1atX0759e3bu3El8fDxpaWkYY7L1HhcR8UTYr9jM3Dops01k5tZJgN8SeXJyMvfddx/nnXceAKtWreLQoUPcdtttTJkyhQkTJnDHHXfw9NNPs3nzZt577z1Kly7NkSNHmDFjBk2aNGH79u307NmTjRs3smfPHtauXcuQIUOoWbMm77//Pg0bNqRhw4a0a9eO1q1b89hjjzF27FimTZvG1KlTAbjrrrv88npEJMTS0uCUadRACvskntPWSXlN4idPniQ6OhprLaVLl6ZMmTIAnHPOORw+7CxEPXLkCMnJyVSuXJlJkyZRqlSprOmW48ePU6hQIf78809eeOEFJk6cyDPPPEN0dDT9+/dnxYoVFClShOrVq2edc+HChSxevJg2bdrw22+/0aJFC/bu3QvA+++/z913303Xrl3z9LpEJESOHIE33oApU2DlSjj77ICfMuyTeCC3TtqzZw9169alVq1ajBkzhqSkJAoVKkT58uW56667SEpK4q+//uKzzz6jRYsW/Pe//2Xw4MGUK1cOgN9//53du3ezY8cOBgwYwObNm1mzZg2rV69m27ZtWftptmrVCoBPPvmEGTNm0LRpUx566CEWLlzIxIkTmTFjBqA5cZGIdfIkjB8PAwfCnj3Qvj0cPqwkDoHdOmnNmjV06NCBSy+9lAYNGrBt2zaKFi2KMYZOnTpRtGhRRo8eTbNmzXjzzTeJjY2lb9++Wa1qGzVqBEDjxo3p2LEjy5Yto1WrVtStW5fo6Gguv/xyZs+ezcCBAwHo0KEDnTt35tZbb+WPP/6gTJky1KhRg8qV
KwPQo0cP7bUpEkmshS++gL59Yf169l8WR782/fmmRHUqjP+V3i1PBnxzibDfKLl3y9rERGefW/LH1klHjhxh48aNxMbG8sADD1ChQoWsKpLixYvTtm1bUlJSGDduHJMnT+bYsWMcOnSI4cOHEx8fT3x8PH/++WfW8z366KNcf/31DB48mF9++YVt27bx66+/ctZZZ7Fy5Uon7pgYjDEYY/j+++9p3rw5V155JQkJCSQkJGRt9SYiESAxEa67Dm69FVJTWT7iA5q0fpn5Japj+ff63exVgR2Yhf1IPFBbJy1atIi7776bdevWMWrUKA4fPsyIESOYNGkS11xzDSVLlmTBggUUKVKE22+/HXCmQ3r37p01Er/88ssBp794t27dqF69OsuWLWP9+vV06dKFESNGUK5cOdq3b8+UKVOoUaMGffv25cYbb2TRokWMGTOGHj16EB8fDzi7/4hImNu2Dfr3h2nToGxZePdd6NqVZ15fTMrJ7LMG/rp+l5MCuT2brzJXahYufObfvpMnT2bdb60lPT09a5HP6W1nT5Wamkp0dHSO543U90skX/n7bxg82EnahQvDM89Anz5wzjkAVHv2izN7kAAG2DbsFq9PF/bbs+WU2MKVq+Tt6nvGmGyrNHN6nbkl8ED/kRUR92avSmLk3HXcuOAzui+bTvETxzBdusBLL0HF7KPrQF6/y0lI5sSLFi3K/v37laByYa1l//796jUuEgKzV+xk6UsjmTjiPvoljCexwkXc/tC7zO5+ZgKHwF2/y01IRuKVKlVi165dqMNh7ooWLUqlSpVCHYZIwfLdd9S571Ha7N7CunI16NOqB0urOnvoupvjDtT1u9yEJIlHR0dTrVq1UJxaRMS9deuccsF584g95zyevLUX/63TDGv+nbTIaY1KmwYVA560Txf21SkiIgGXlOQs1Pn4Y+dC5YgRdD5Wl21H0844NNBz3N4K+zpxEZGAOXQIBgyAWrVg0iR4+mnYuhV69qRH63ohmeP2lkbiIlLwpKbC2LFOlclff8Hdd8OQIVC1atYhoZrj9paSuIjkO27bV1sL//kPPPssbN7srLgcPhyuuMLl87ib4w50e2xvKImLSJ6EU0LLjMdV++rSaxJp9sFrsHQp1K3r9Dxp1Qq8XK8SjPbY3lASFxGfhVtCgzPbV1f9O4m+CyfQbNNSKF8ePvwQ7r/fWXXph+eH4Cyvd0cXNkXEZzkltFDJLAE892gyL30zmm/GPUbT7at4vem9/HfGQhr/VY1qA76m8bAFPjWnCmR7bF9oJC4iPgu3hAZQ7WzDTd99Srdl04lJPc6U+q14u/GdnCxzHsfnbc3zp4ZQLa93RyNxEfGZu8QVkoSWlgYffcQX7z5In0WfsPSCy7jx/95j4I3dOFqyDMbgl08NoVpe746SuIj4LCwSmrUwbx7Urw//93/E1KjGonGzePnBIWw7txIVS8YwtG09ko+luny4t58a2jSoyNC29ahYMgYDWc8f1tUpxphywAxrbVNjTBXgEyAd2AI8YtXJSqRACnkt9apV0Ls3fPcd1KgB06dDu3Y0M4Ylpx06/OuNfpsGCcXyendyTeLGmFLABCA2465HgG7W2g3GmHlAPWBt4EIUkXAWklrqHTuclZaTJsG558Lbb8Mjj0CRIm4f0rtl7WyVNBCeKzC95clIPA3oBMwBsNY+d8r3zgX2BSAuEYlgvpQeepT0DxyAV15xknahQtCvn9OwqkSJXGMK+aeGAMk1iVtrD8GZGxsYYzoBv1prd5/+GGNMV6ArQJUqVfwSqIhEDm9rqXNN+sePw6hRzs46yclOnfegQeBlm+ZwmgbxF58ubBpjqgO9gKdcfd9aO9ZaG2etjStbtmxe4hORCORt6aG7pD9i3gaYOhUuugh69oSGDWH1ahg/3usEnl95XSeeMUc+FXjQWnvQ/yGJSKTztpbaVXK/6o+19Pt+POzZDJddBvPnww03+D3WSOfLYp9ngSrAOxlTLC9Y
axf6NSoRiWjeXkQ8NenX+msHfRd+TIutP7OnxHnwySdwzz3OHPgpwq1nS6iEZLd7Ecn/vEmys1cl8caEhTya8Amd1n7D0SIxjG3ciQsH9+O2q2u6PN7VH4lQ1mv7m6e73SuJi0hoHT4MI0Zw8rXh2NRUJta/mRmtHqBr24ZuE3LjYQtcTtdULBnDkmebBzrioPA0iat3ioiERmqq01HwxRdh714Kd+oEQ4bwYI0aPJjLQ8OxZ0uoaNm9iASXtTBnDtSrB489BrVrw7JlMG2as+rSA2HVsyXElMRFJHiWLYNmzaBNG2czhjlzYOFCaNTIq6cJi54tYUJJXEQCb+tW6NgRrr7a2RZtzBhYtw5uu83rnXUg/JpQhZLmxEUkcPbtc1ZWjh4N0dHwwgvOop3ixfP81Plx9aUvlMRFxP9SUmDkSBg6FI4cgYceci5gli8f6sjyHSVxEfGftDSns+CAAbBrF9x6KwwbBhdf7PNTalFPzpTERcQ/5s+HPn1gzRq48konmV97bZ6eMhw3Yg43urApInmzZg20bOn8O3TIKRVctizPCRzCcyPmcKMkLiK+2bnTaQnboAH8/DO88QZs2ACdOp3R58RXWtSTO02niESQsJgfPnjQmed+6y1n4U6vXs7mDKVK+f1U4bazfDjSSFwkQmTODyclp2D5d3549qqk4ARw4oSzo06NGk4S79ABNm6E114LSAIHLerxhJK4SIQI2fywtfDZZ1CnDvTo4ewqv3Kl0yL2ggsCemot6smdplNEIoQv88N5nn5ZvNiZLvnpJ6fXyZdfwk03+bTK0lda1JMzjcRFIoS3TZ/yNP3y229Of5NmzSApCT76CFatglatgprAJXdK4iIRwtv5YZ+mX/bsgW7d4JJLYMECZ2f5TZugSxeIinL/OAkZTaeIRIjMKQVPp0e8mn45csQpEXztNWdn+W7dYOBA0EbnYU9JXCSCeDM/7FF53smTzs7xAwc6o/D27Z3Rd61a/gpZAkzTKSL5VI7TL9bC3LnOLvJduzplg0uXwvTpSuARRiNxkXzK7fRL2v/gunuczRguvBBmzcrapCEYi4nCYsFSPqIkLuJGfkg22aZftm2D/r2c3iZly8KoUfDww06fb4LTbEoNrfzPo+kUY0w5Y8ziU27XMcbMCVxYIqEV8tWR/vT33/DMM85elnPmOG1it2xx9rfMSOAQnMVEamjlf7mOxI0xpYAJQGzG7RrAcKBYYEMTCZ2ckk3EjBj/+Qfeece5UHnokFMm+PLLUKFCxqeMxGyfMoLRbEoNrfzPk5F4GtAJOJRx+zDQLmARiYQBX1dHNh62gGrPfkHjYQtCN2pPT3d6edeu7fT3vuYap13shx9mJXBXnzJKnh3t8un82WxKu9T7X65J3Fp7yFp78JTbe621x3N6jDGmqzEm0RiT+Ndff/kjTpGgCurqSH/69luIi4POnaFMGfjuO/jiC2fxTgZ3nzKsJeDNptTQyv8CUmJorR1rrY2z1saV1WIBiUBBWR3pT2vXOkvib7jBmQOfPNnp8d28+RmHuvs0cTAlNeDNptTQyv9UnSLigj9XR7qrcvFL9cuuXc5CnY8/hhIlYPhwePxxKFrU7UNyWgQUjGZTamjlX0riIm74Y3VkiZholyV1iTv+ZuaKJN9L7Q4dgldfhTffdDYnfuYZ6N8fSpfO9aG9W9bOFhNoSiOSeTydYq2Nz+m2SEHmbvrFGFxOs0xdvtO36ZcTJ+Ddd50Vlq+8Anfc4XQcHDHCowQOvk1phM1FWzmDRuIifuBu+uXpT1e7PD7NWpf3u61+sdZZWfnss06Nd3y8M3USF+dzvJ5+ytACnfCmJC7iJ64S4/CvN7qcZokyxmUid1n9smQJ9O4NP/4IF1/s9Dy5+eag9fXOFzXz+ZgaYIkEkLtplrsaVc69+mXTJmjbFpo0ge3b4YMPnHrvW24J6sYMWqAT3jQSFwmgnKpc4i4o7bo6Ze9eZ2Xl++87VSYvv+xcuIyN
Dclr0I7z4c1YN3Nz/hIXF2cTExMDeg6RfOHYMafa5NVXna8ffhhefBHKlQtpWKfPiYPzqUH13YFljFlhrc31oodG4iKhlpbm1HkPHAi7dzttYYcNc5bNhwFva+YluJTEJV+JqPax1sJXXzn9TX75Ba66Cj791JkDDzNaoBO+dGFT8o2w6V/iiZUroUULp8okJQU++8zZWScME7iENyVxyTdC3r/EE9u3w733whVXOJUmb78N69dDhw5BrTiR/EPTKZJvhHUp3IEDzgrLt9+GQoWgXz/o29fpdyKSB0rikm+EZSnc8ePONmiDB0NyMtx3HwwaBJUrhy4myVc0nSL5Rlj1qk5Ph6lT4aKLoGdPaNgQVq1yqlCUwMWPNBKXfCNsSuESEpxl8omJcNllMH++0+dbJACUxCVfCWkp3K+/Og2q5s51RtsTJsA990BUVO6PFfGRkrgUCAGtH//f/5yFOh99BMWKwdCh0KMHxARvLj6i6uPFr5TEJd8LWCvVw4edPt4jRkBqKjzxBAwY4OxtGURqFVuw6cKm5Ht+rx9PTYUxY6BmTac5VevWsGEDvPVW0BM4REh9vASMRuKS7/mtftxa+Pxzp75740Zo2tS53aiRH6L0XVjXx0vAaSQu+Z67OnGv6seXL4drr3WaUxkDc+bAwoUhT+Dgp9cnEUtJXPK9PNWPb90KHTs6zak2bXKmUdatg9tuC9gyeW/3swyr+ngJOk2nSL6XW/24y8qOymc5KytHj4boaHjhBejVy6k+CSBfLlKGTX28hIQ2hZAC7fSkeVbqcbqumsuTP88g+thReOghZ2OG8uWzPSZQCbPxsAUuWwdULBnDkmeb++UcEhn8uimEMaYcMMNa29QYEw3MAkoD46y1H+UtVJHQyazsKJSexh2/JtBz8UQqHN7HD3WupsnMcVCnTrbjA13Op4uU4q1c58SNMaWACUDmBn9PACustY2B9saY4gGMTySgdien0HTbSuZOeIrXv3yTvcVK0emuoXS+7bkzEjgEvpxPFynFW55c2EwDOgGHMm7HA59lfL0IyHW4LxKW1qxh2qwXmfjZQIodP8bjt/Xhjs6vs7xKPbdJM9AjZV2kFG/lOp1irT0EYP69Eh8LZF4u/xs4YxdXY0xXoCtAlSpV/BGniP/s3OmsrJw4kQbnlGDoDV0Zf2krThSOBnJOmoFud6uLlOItX6pTjgAxwEGgWMbtbKy1Y4Gx4FzYzEuAIn6TnOxsQPzWW87t3r0p0q8fdbYdpayHSbN3y9oud37350hZ+1mKN3xJ4iuAJsAM4DJgmV8jEvG3EyecUsFBg2D/fujc2fn6ggsAaNOgpMdJUyNlCTe+JPEJwJfGmKbAxcBy/4Yk4ifWwvTpzlZov/8O118Pw4dDgwZ5elqNlCWceJzErbXxGf/dYYy5AWc0PtBam5bjAyXfCuv2p4sXO4tzfvoJ6tWDefOgZUttRiz5jk8rNq21u/m3QkUKoFC3P3X7B+S335wGVZ9/DhUrOj2+77tPGzNIvqXeKeKTULY/zfwDkpScgsX5A/L6JwvZ1r4zXHIJfP89DBni9Drp0kUJXPI19U4Rn4RyZeGpf0DOPpHCwz/9h64/zeKstFR4rJuzy07ZsgGPQyQcKImLTwJdL53J1bTJ7uQUotLT6Lj2G57+YTLnHT3Alxdew/Br7+f7d7r69fwi4U5JXHwSjHppl/PuM9dy284VPP71h9Tav5PEinV49I7+rKxYh4pami4FkJK4+CQY9dKnz7tf+r9N9P/+I67a+QvbS1fkkTv683Wtq8EYLU2XAktJXHwW6HrpzPn1Ssl76LPoE27bsIh9Z5fg+Ru6ETe4D78s+B0TjuWNIkGkJC5B5U1t+UXRJ2j31QTuW/kFaYWiePvqToxt1I4S5c5lUMOq3N6wanCDFwlDSuISNB7Xlv/zD7zzDp+/NZhCR48wvV4L3mxyN38WL6NpE5HTqE5cgibX2vL0dJg0CWrXhj59iG7WhIXT
vuadu/qyt3gZKpaMYWjbepo2ETmFRuISNDnWln/3HfTuDatWweWXw/jx0Lw5zQFtSibinkbiEjSuashr/7WdKf95CVq0cDoMTpoEP/8MzZW6RTyhJC5Bc+quNeUO7+PVL0fy5fgnueJ/m5zughs3wj33QCH9WIp4StMpkitvuxW6O75Ng4oUPnKY/QMH0fGHmUTZdLbd8xA1Rw6F0qWD+IpE8g8lccmRt90K3R1vUlO5/ae5tH7pJdi3D+66C4YMoWa1asF7MSL5kD63So687VZ4xvHWcu0vi2hwcxN44gmny+DPP8OUKaAELpJnGolLjrztVnjq/Zfv2kD/hI+IS9rApnOrwNy5cPPN2phBxI+UxCVH3nYrrFAyhiK/b6HPwgm02rSUP4uVps9NT7KsSWsW3XJDoMMVKXCUxPOBQG6T5lW3wr17mbjqEyrPmMTxwkV4vck9fHjlHRAby9CbL/ZLPCKSnZJ4hAv0NmkedSs8dgzefBNefZXqx47xe7t7ebJWa35Ni1FzKpEAM9bagJ4gLi7OJiYmBvQcBVnjYQtcTndULBnDkmcDvGAmLQ0mTIDnn4fdu6FNGxg2zFk2LyJ5YoxZYa2Ny+04r0fixphqwLvAOcBP1tqePsQnfhKSbdKsha++gj594JdfoFEj+PRTaNLE56cM5JSQSH7mS4nhq8Aga21ToJIxJt6/IYk3crrAGBArVzpL5G++GVJSYPp0+PHHPCfw0zc+7jdrHbNXJfkvbpF8ypckfiGwMuPrvUAJ/4Uj3jp1KXumgLRr3bEDOneGK66ANWvg7bdh/Xpo3z7PJYPe1qKLyL98ubA5A3jBGLMMuAnod/oBxpiuQFeAKlWq5ClAyVnAt0k7cACGDnWStjHQrx/07Qsl/Pe3OyRTQiL5hNdJ3Fo72BjTBOgNTLDWHnFxzFhgLDgXNvMcpeQoINukHT8O770HgwZBcjLcfz+8/DJUruzf8+B9LbqI/MvXZfergSrAG36MRcJBejpMmwZ16sAzz0DDhk6P7/HjA5LAIYhTQiL5kK914r2BN6y1x/wZjPiX1xUfCQnOxgyJiVC/PsyfDzcEfpVlwKeERPIxn5K4tfYFfwdS0AS6pC6nRUCQPWG+XBOun/Cm09ukcmUyQir7AAANOElEQVSn9vvee4Pa1zsgU0IiBYBWbIZAoFdZgvuKjxc//5XjJ9NJSU3jvMP7eXzeZOLXfUtqbCzRw4bBk09CjOaiRSKFkngI5FZS548RurvKjuSUVGKPH+Ppn2bx8M//oXBaGhMub82MVg/wZd87vH8xIhJSSuIh4C7BZo7I/TFCd1XxUTjtJHeunU+PH6ZQ9lgycy9qymvN7uOPUuUxqT68EBEJOW0KEQLuSueijPHbopdsFR/WcuOmH5k//nEGz3+P30tXpE3n13n89r78Uap8jjGJSHjTSDwE3LV3PT2BZ/Jl0UvmyP3LD/7Dw/8dzZW71nO4ak2WDfiILvvOJ+VkerZzq5xPJDJpJB4CbRpUZGjbelQsGYPB6TiYedsVn0bJW7bQ5pWnGDv6Ca5M3Q9jxlB88waueqoLQ9tdesa5VRkiEpk0Eg8RdyV1Hm/A4M6+fc4qy9GjIToaXngBevaE4sVzPbeIRB4l8TCSp0UvKSkwcqTT5+TIEXjoIXjxRShfPrBBi0hIKYmHGa9HyWlpMGkSDBgAu3bBrbc6GzNcrO3QRAoCzYlHsvnzndawDzwA558P338Pn3+uBC5SgCiJR6I1a+DGG6FlSzh0CKZOheXLIT4+1JGJSJApiUeSnTudlrANGjhNqt54AzZsgDvvDGqfExEJH5oTjwQHDzoXLEeOdPa37NXL2ZyhVCm/nkb7XIpEHiXxcHbihFMqOGgQ7N/vdBYcPBguuMDvpwpGUy4R8T99Bg9H1sJnnzkbMzz1lNPbe8UKmDgxIAkctM+lSKRSEg83ixfD1VdDp05w9tkwbx588w1cfnlAT6t9LkUik5J4uPjtN2jTBpo1
cy5gjhsHq1fDTTfleTd5T7hb2q/GWCLhTUk81PbsgW7d4JJLYMECGDIENm+GBx+EqKjcH+8n2udSJDLpwqafeF3ZcfQovP46vPaas7P8o4/CwIFw3nn+eX4vaZ9LkcikJO4HXlV2nDzp7Bw/cKAzCm/b1ikfvPBC/zx/Hqgxlkjk0XSKH3hU2WGtsxHxZZdB165QvTosWQIzZ+aYwD1+fhEpkJTE/SDXyo7ERGje3GlOlZoKs2bBDz/ANdf45/lFpMDyejrFGFMKmAycB6yw1j7i96iCxF/zzK72swS4Ij0Z7roLpk2DsmVh1Ch4+GGnz7cfnl+VIyLiy0i8MzDZWhsHFDfGxPk5pqDInGdOSk7B8u888+xVSV4/1+mVHSVSDvNiwjg+fetBmDPHaRO7ZQs89pjXCdzV84MqR0TE4cuFzf3AJcaYkkBlYKd/QwqOnOaZvR2NZx4/cu46blzwGd2XTaf4iWOYLl3gpZegYt4uFqpyRETc8SWJ/wDcAjwJbAD+Pv0AY0xXoCtAlSpV8hJfwPh1njk9nTbrE2gz7jnYsQNuvhlefdWp/fYTVY6IiCu+TKe8ADxqrX0Z+A3ocvoB1tqx1to4a21c2bJl8xpjQPhtheKCBXDllU5zqnPPhe++gy++8GsCFxFxx5ckXgqoZ4yJAhoB1r8hBUee55l/+cUZcV9/vdNhcPJk+PlnpwpFRCRIfEniQ4GxwEGgNDDVrxEFSZsGFRnath4VS8ZggIolYxjatl7uUxZJSfB//+fUe//4Iwwf7vQ9uftubcwgIkHn9Zy4tfYnoG4AYgk6r+aZDx1ylsi/8YazOfFTT8Fzz0Hp0oENUkQkB1p2n5vUVHj/fafKZN8+p+57yBCoVi3UkYmIaMWmW9Y6S+Lr1oUnnnAuVP78M0yZogQuImFDSdyVpUuhcWNo395ZnDN3rlOFEheR65pEJB9TEj/Vpk3Qrp2TwLdvh7FjYc0auOWWoGzMICLiLSVxgL17oXt3uPhimD8fXn7Z2Zjh4YehsC4biEj4KtgZ6tgxePNNZ3XlsWNOi9gXXoBy5UIdmYiIRwpmEk9LgwkT4PnnYfduuP12GDYMLroo1JGJiHilYE2nWOvsHl+/vrNgp0oVZ3f52bOVwEUkIhWckfjKldC7t1NlUqMGTJ/uXMQM8AXLQO+NKSIFW/4fie/Y4TSnuuIKp9Lk7bdh/XqnfDAICdxfPctFRFzJv0n8wAFn5H3hhc6inX79YOtWZ+FOkSJBCUF7Y4pIoOW/6ZTjx51t0AYPhuRkuP9+GDQIKlUKeijaG1NEAi3/jMTT02HqVOcCZc+e0LAhrFoF48eHJIGDH3uWi4i4kT+SeEICNGrktIMtUQK+/hq++sppFxtC2htTRAItsqdT1q+Hvn2d3iaVKzu13/feG9C+3t5Um2hvTBEJtMhM4v/7n7Oyctw4KFbMWajz5JMQE9hpisxqk8yLlZnVJkCOiVxJW0QCJbKmUw4fdpJ3zZrw8cdOpcnWrc5oPMAJHFRtIiLhJzJG4qmp8OGH8OKLTrOqjh3hlVecRTtBpGoTEQk34Z3ErYU5c+DZZ2HjRmjaFD7/3LmIGUDu5r0rlIwhyUXCVrWJiIRKeE+nrF8Pd9zhrKycMwcWLgxKAne3ylLVJiISbsJ7JF63rtPf+7rrgtbXO6d57yXPNs86RtUmIhIOfMqMxphuQKeMmyWB5dbaR/wW1aluuCEgT+tObvPeqjYRkXDi03SKtXa0tTbeWhsPLAY+8GtUIaRVliISSfI0R2GMqQiUs9Ym+ikej7i78OiPtq+9W9bOVgsOmvcWkfCV14nm7sDo0+80xnQFugJUqVIlj6fIzt2Cm8QdfzNzRZLLhTjg+Ty2VlmKSCQx1lrfHmhMIWAJcI3N4Uni4uJsYqL/BuqNhy1wWeYXZQxpLsIoGRPN8ZPpZ4ysh7atp8QsImHLGLPCWhuX23F5KTFsinNB07e/Aj5yd+HRVQIHSE5J1SpLEcm3
8pLEWwKL/BWIp9xdYIzycpcerbIUkfzA5yRure1vrZ3lz2A84W7BzV2NKru8v9TZ0S6fR9UmIpIfhPVin5yqTVzdH3dB6TPuB1RtIiL5Vtgm8dzavrq6KJnTQhxVm4hIfhS2STyn5e/eJmCtshSR/CpsG2Cp7auISO7CNolr+buISO7CNomr7auISO7Cdk5cy99FRHIXtkkcdEFSRCQ3YTudIiIiuVMSFxGJYEriIiIRTElcRCSCKYmLiEQwnzeF8PgExvwF7Mjj05QB9vkhHH8Kx5hAcXkjHGMCxeWtcIzLHzFdYK0tm9tBAU/i/mCMSfRkh4tgCseYQHF5IxxjAsXlrXCMK5gxaTpFRCSCKYmLiESwSEniY0MdgAvhGBMoLm+EY0yguLwVjnEFLaaImBMXERHXImUkLiIiLiiJi4hEsLBI4saYcsaYxbkcE22M+a8xZokx5kF39/k5rnHGmB+NMQNyOKabMSYh499qY8z7xpjCxpg/Trm/XgjichmDMeYlY8zPxphRIYiphDFmnjFmvjHmP8aYIoF8rzyM6YxjPHlcoGIK9nvkRVxB/XnyIq6g//5lnDfHnBXMfBXyJG6MKQVMAGJzOfQJYIW1tjHQ3hhT3M19/oqrLRBlrb0aqG6MqeXqOGvtaGttvLU2HlgMfABcCkzNvN9auy7YcbmKwRhzBdAEaAjsNca0CHJM9wBvWGtvBPYAN7mKM1gxuTrGi9cSkJgI4nvkZVxB+3nyJq5g//5lxOVJzgpavgp5EgfSgE7AoVyOiwc+y/h6ERDn5j5/OfW55+P8sLpljKkIlLPWJgJXAa2NMT9ljCT82bfd07hcxXAtMNM6V7O/BpoGMyZr7XvW2m8ybpYF9rqJM1gxuTrGk8cFLKYgv0cex+UmhkD9PHkTFxDU3z/wLGfFE6R8FfQknvFxJ/NjTgLwlLX2oAcPjQWSMr7+Gyjn5j5/xfWEl8/dHRid8fXPQAtrbUMgGrg5BHG5isEv71de3ytjzNVAKWvtMjdx+oMnrzWgP1M+xgQE7T3yJq6A/TzlMa5MAfn9c8Vae8iDnBW0n62g7+xjrX3Ex4ceAWKAg0CxjNuu7vNLXMaYkRnPTcZzu/2DZ4wpBFwHPJdx11pr7fGMrxMBnz+S5yEuVzFkvl+5PTZQMWGMKQ28A7TLIU5/8OS1ujrGL+9RHmIK5nvkTVwB+3nKY1wB/f3Lg4Dmq1OFw3SKp1bw78epy4Dtbu4L5PncaQost/8W3U80xlxmjIkC2gBrQhCXqxgC9X559LzGmCLAdKCftTazKVqg3itPYgq7n6kgv0cex+UmhkC+V57GBcH9/fNU8H62rLVh8Q9IOOXr5sDjp33/AuBXYCTOx6UoV/f5MZ5zcP7nvwFsAEoAFwODXRz7CtD2lNuXAGuBdcAQP79PHsXlKgacP9pLMt6vjUC1IMfUDTgAJGT86xSo98pFTJe5iMdV3GfcF+SYgvYeeRlX0H6evIkr47ig/f6ddt6EjP+GNF9F1IpNY0wFnL9kX9uMOSlX9/nxfKWAG4BF1to9/nzuvMhLXMaYGOAWYKW19vdwiClQPInJ1TGBfC3h+D6B73EF6ucpr3GFg2Dlq4hK4iIikl0kzYmLiMhplMRFRCKYkriISARTEhcRiWBK4iIiEez/AWfboa6BrkNGAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Switch to evaluation mode before predicting\n",
    "model.eval()\n",
    "# Inference only: disable autograd so no computation graph is built\n",
    "with torch.no_grad():\n",
    "    y_hat = model(x)\n",
    "# Scatter the original samples and overlay the fitted line in red\n",
    "plt.scatter(x.numpy(), y.numpy(), label='原始数据')\n",
    "plt.plot(x.numpy(), y_hat.numpy(), c='r', label='拟合直线')\n",
    "# Show the legend\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "查看参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('fc.weight', Parameter containing:\n",
       "  tensor([[2.7996]], requires_grad=True)), ('fc.bias', Parameter containing:\n",
       "  tensor([9.9626], requires_grad=True))]"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Collect every (name, parameter) pair registered on the model\n",
    "[(name, param) for name, param in model.named_parameters()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
